{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aaf8878-307d-49c6-a21b-90607770ab0d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# %pip (unlike !pip) always installs into the environment of the running kernel\n",
    "%pip install zarr scipy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb3fb8cb-5e98-4af3-ab6e-51c33852e95b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import dask\n",
    "import dask.array as da\n",
    "import dask.dataframe as dd\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pyarrow as pa\n",
    "\n",
    "from os.path import join\n",
    "import shutil"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c75ee8b-373f-44a8-83c3-a8ec7f0b3fd4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from dask.distributed import Client, LocalCluster\n",
    "\n",
    "\n",
    "# Number of local workers. The thread count per worker is not set explicitly here;\n",
    "# the sizing note was: assume 20 cores on LRZ -> 5 workers with 4 threads each.\n",
    "N_WORKERS = 5\n",
    "\n",
    "cluster = LocalCluster(n_workers=N_WORKERS)\n",
    "client = Client(cluster)\n",
    "# last expression of the cell: show the client summary\n",
    "client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09d5274a-e048-441a-ab06-d162770febe5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Normalization applied to the count matrix before it is written to the store:\n",
    "#   'sf-log1p' -> normalize to 10000 counts + log1p transform data\n",
    "#   'raw'      -> don't normalize data\n",
    "NORMALIZATION = 'sf-log1p'\n",
    "\n",
    "assert NORMALIZATION in ('sf-log1p', 'raw')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63a4aea0-0af9-47c6-9ec6-5743aeae6375",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from scipy.sparse import csc_matrix, csr_matrix, issparse\n",
    "from sklearn.utils import sparsefuncs\n",
    "\n",
    "\n",
    "def sf_normalize(X):\n",
    "    \"\"\"Scale every row of ``X`` so that it sums to 10000 (size-factor normalization).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : numpy.ndarray or scipy sparse matrix\n",
    "        2-D matrix with one sample per row. Integer input is accepted and is\n",
    "        converted to float32; floating point input keeps its dtype.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    Same container type as ``X``\n",
    "        A normalized copy, ``X`` itself is never modified. Rows that sum to\n",
    "        zero stay all-zero.\n",
    "    \"\"\"\n",
    "    # flatten to shape (n_rows,): summing a sparse matrix over axis 1 yields an (n_rows, 1) matrix\n",
    "    counts = np.asarray(X.sum(axis=1)).ravel()\n",
    "    # avoid zero division error: all-zero rows get a finite scaling factor and stay zero\n",
    "    counts[counts == 0] = 1\n",
    "    # normalize to 10000. counts\n",
    "    scaling_factor = 10000. / counts\n",
    "\n",
    "    # The scaling below happens in place, which raises a casting error on integer\n",
    "    # buffers, so non-float input is converted first (astype returns a copy).\n",
    "    if np.issubdtype(X.dtype, np.floating):\n",
    "        X = X.copy()\n",
    "    else:\n",
    "        X = X.astype('f4')\n",
    "\n",
    "    if issparse(X):\n",
    "        sparsefuncs.inplace_row_scale(X, scaling_factor)\n",
    "    else:\n",
    "        np.multiply(X, scaling_factor.reshape((-1, 1)), out=X)\n",
    "\n",
    "    return X\n",
    "\n",
    "\n",
    "def sf_log1p_norm(x):\n",
    "    \"\"\"Run ``sf_normalize`` on ``x``, apply ``log1p`` and return the result as float32.\"\"\"\n",
    "    scaled = sf_normalize(x)\n",
    "    log_transformed = np.log1p(scaled)\n",
    "    return log_transformed.astype('f4')\n",
    "\n",
    "\n",
    "def preprocess_count_matrix(x, normalization):\n",
    "    \"\"\"Apply the requested normalization to the blocked count matrix ``x``.\n",
    "\n",
    "    'raw' hands ``x`` back untouched; 'sf-log1p' maps ``sf_log1p_norm`` over\n",
    "    every block (declared output dtype 'f4'). Anything else raises ValueError.\n",
    "    \"\"\"\n",
    "    if normalization == 'raw':\n",
    "        return x\n",
    "    if normalization == 'sf-log1p':\n",
    "        return x.map_blocks(sf_log1p_norm, dtype='f4')\n",
    "    raise ValueError(f'NORMALIZATION has to be in [\"sf-log1p\", \"raw\"]')\n",
    "\n",
    "\n",
    "@dask.delayed\n",
    "def convert_to_dataframe(x, start, end):\n",
    "    \"\"\"Turn the rows of a 2-D block into a single-column DataFrame (lazy dask task).\n",
    "\n",
    "    Every row of ``x`` is stored as one float32 array in column 'X'; the frame\n",
    "    is indexed by ``pd.RangeIndex(start, end)``.\n",
    "    \"\"\"\n",
    "    row_blocks = np.vsplit(x, x.shape[0])\n",
    "    cells = [block.squeeze().astype('f4') for block in row_blocks]\n",
    "    return pd.DataFrame({'X': cells}, index=pd.RangeIndex(start, end))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5851124f-a8ab-4747-9cc5-1254088c196d",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec9ef276-f4d0-40b4-bb68-c886633a66da",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# input: split data (train / val / test); output: store written by this notebook\n",
    "DATA_PATH = '/mnt/dssfs02/cxg_census/data_2023_05_15'\n",
    "OUT_PATH = f'/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_{NORMALIZATION}'\n",
    "\n",
    "# exist_ok=True keeps this cell re-runnable (e.g. Restart Kernel -> Run All)\n",
    "os.makedirs(OUT_PATH, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e78af09-08ec-4c5b-9b8f-90e3cc85f9e5",
   "metadata": {},
   "source": [
    "## Copy var dataframe + norm data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d13a0bc-c223-4760-9f7e-686844c255c8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# copy the var dataframe of the train split into the output folder\n",
    "var_source = join(DATA_PATH, 'train', 'var.parquet')\n",
    "var_target = join(OUT_PATH, 'var.parquet')\n",
    "shutil.copy(var_source, var_target);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8401d2a-0dcb-4980-b712-0ce28c60679c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Recursively copies the 'norm' folder from DATA_PATH into OUT_PATH.\n",
    "# only run if NORMALIZATION == 'sf-quantile'\n",
    "# NOTE(review): 'sf-quantile' is not among the values accepted by the NORMALIZATION\n",
    "# assert ('sf-log1p', 'raw'), and nothing in this cell enforces the condition, so the\n",
    "# copy is executed on every Run All - confirm whether that is intended.\n",
    "!cp -r {join(DATA_PATH, 'norm')} {join(OUT_PATH, 'norm')}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01ed62d7-f3ed-4470-8f16-786145f2dafe",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "f279a9d4-4146-4b76-b276-bef7c29e1103",
   "metadata": {},
   "source": [
    "## Create lookup tables for categorical variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a07e899-81ce-40cf-a5f9-5c3feb33f4ac",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from pandas import testing as tm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddf6b670-6ba1-4235-9fde-05947fa520bd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def _read_obs(split):\n",
    "    \"\"\"Read obs.parquet of one split and replace its index by a fresh 0..n-1 range.\"\"\"\n",
    "    return pd.read_parquet(join(DATA_PATH, split, 'obs.parquet')).reset_index(drop=True)\n",
    "\n",
    "\n",
    "obs_train = _read_obs('train')\n",
    "obs_val = _read_obs('val')\n",
    "obs_test = _read_obs('test')\n",
    "\n",
    "# all splits stacked on top of each other (row labels restart with every split)\n",
    "obs = pd.concat([obs_train, obs_val, obs_test])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55d7697a-57f2-4615-949a-f4840896e2d1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# all three splits have to expose the same columns in the same order\n",
    "cols_train = obs_train.columns.tolist()\n",
    "for other_split in (obs_val, obs_test):\n",
    "    assert cols_train == other_split.columns.tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5f4941e-000a-4264-962b-9b2347c2a749",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Give every categorical column the same category list in all three splits.\n",
    "#\n",
    "# pd.concat falls back to object dtype as soon as the splits carry different\n",
    "# category sets, so the stacked `obs` frame cannot be used to detect categorical\n",
    "# columns; the check is done on the split frames instead.\n",
    "categorical_cols = [\n",
    "    col for col in cols_train\n",
    "    if all(df[col].dtype.name == 'category' for df in (obs_train, obs_val, obs_test))\n",
    "]\n",
    "\n",
    "for col in categorical_cols:\n",
    "    # union of the categories of all splits; with identical category sets this keeps\n",
    "    # their existing order, and categories that never occur are dropped\n",
    "    merged = pd.api.types.union_categoricals(\n",
    "        [obs_train[col], obs_val[col], obs_test[col]], ignore_order=True\n",
    "    ).remove_unused_categories()\n",
    "    categories = list(merged.categories)\n",
    "\n",
    "    obs[col] = pd.Categorical(obs[col], categories, ordered=False)\n",
    "    obs_train[col] = pd.Categorical(obs_train[col], categories, ordered=False)\n",
    "    obs_val[col] = pd.Categorical(obs_val[col], categories, ordered=False)\n",
    "    obs_test[col] = pd.Categorical(obs_test[col], categories, ordered=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee8fedbb-0e17-4e48-8db5-3f9db7218a69",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "lookup_path = join(OUT_PATH, 'categorical_lookup')\n",
    "# exist_ok=True keeps this cell re-runnable (e.g. Restart Kernel -> Run All)\n",
    "os.makedirs(lookup_path, exist_ok=True)\n",
    "\n",
    "\n",
    "def _category_lookup(column):\n",
    "    \"\"\"Return a one-column frame ('label') that maps integer code -> category of ``column``.\"\"\"\n",
    "    return pd.Series(dict(enumerate(column.cat.categories))).to_frame().rename(columns={0: 'label'})\n",
    "\n",
    "\n",
    "for col in cols_train:\n",
    "    if obs_train[col].dtype.name == 'category':\n",
    "        cats_train = _category_lookup(obs_train[col])\n",
    "        cats_val = _category_lookup(obs_val[col])\n",
    "        cats_test = _category_lookup(obs_test[col])\n",
    "\n",
    "        # every split has to agree on the code -> label mapping\n",
    "        tm.assert_frame_equal(cats_train, cats_val)\n",
    "        tm.assert_frame_equal(cats_train, cats_test)\n",
    "\n",
    "        # one lookup file per categorical column, named after the column\n",
    "        cats_train.to_parquet(join(lookup_path, f'{col}.parquet'), index=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "398a8a0a-007b-4058-8ead-b90085bb412d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# only use integer labels from now on\n",
    "for col in cols_train:\n",
    "    if obs_train[col].dtype.name == 'category':\n",
    "        obs_train[col] = obs_train[col].cat.codes.astype('i8')\n",
    "        obs_val[col] = obs_val[col].cat.codes.astype('i8')\n",
    "        obs_test[col] = obs_test[col].cat.codes.astype('i8')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5268e753-5aed-460a-814d-83fa9e2d17c1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# obs table of every split, keyed by split name\n",
    "obs_dict = dict(train=obs_train, val=obs_val, test=obs_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfcbbc45-4c01-4f36-aeea-98c2ad0689dc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.utils.class_weight import compute_class_weight\n",
    "\n",
    "# calculate and save class weights\n",
    "# NOTE(review): the weights cover only the labels that occur in the train split\n",
    "# (np.unique of its 'cell_type' column) - confirm that consumers do not index the\n",
    "# saved array by label code in case some codes are missing from train.\n",
    "cell_type_labels = obs_train['cell_type']\n",
    "class_weights = compute_class_weight(\n",
    "    'balanced',\n",
    "    classes=np.unique(cell_type_labels),\n",
    "    y=cell_type_labels,\n",
    ")\n",
    "\n",
    "with open(join(OUT_PATH, 'class_weights.npy'), 'wb') as f:\n",
    "    np.save(f, class_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13bfdecc-e1ec-42a3-a134-4c0c59064ce4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "32a708b7-fde4-4f3c-b9d0-c1ba820a4a5c",
   "metadata": {},
   "source": [
    "## Write store"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "238f9e13-ce0d-498e-b832-fa824737c625",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "CHUNK_SIZE = 32768  # rows per dask chunk / dataframe partition\n",
    "ROW_GROUP_SIZE = 1024  # rows per parquet row group\n",
    "\n",
    "# column layout of the written parquet files; identical for every split\n",
    "schema = pa.schema([\n",
    "    ('X', pa.list_(pa.float32())),\n",
    "    ('soma_joinid', pa.int64()),\n",
    "    ('is_primary_data', pa.bool_()),\n",
    "    ('dataset_id', pa.int64()),\n",
    "    ('donor_id', pa.int64()),\n",
    "    ('assay', pa.int64()),\n",
    "    ('cell_type', pa.int64()),\n",
    "    ('development_stage', pa.int64()),\n",
    "    ('disease', pa.int64()),\n",
    "    ('tissue', pa.int64()),\n",
    "    ('tissue_general', pa.int64()),\n",
    "    ('tech_sample', pa.int64()),\n",
    "    ('idx', pa.int64()),\n",
    "])\n",
    "\n",
    "\n",
    "for split in ['train', 'val', 'test']:\n",
    "    X = preprocess_count_matrix(da.from_zarr(join(DATA_PATH, split, 'zarr'), 'X'), NORMALIZATION)\n",
    "    obs_ = obs_dict[split]\n",
    "    # cut off trailing samples so that all row groups are full\n",
    "    n_samples = X.shape[0]\n",
    "    n_samples = (n_samples // ROW_GROUP_SIZE) * ROW_GROUP_SIZE\n",
    "    X = X[:n_samples].rechunk((CHUNK_SIZE, -1))\n",
    "    obs_ = obs_.iloc[:n_samples].copy()\n",
    "    # add an index column to identify each sample\n",
    "    obs_['idx'] = np.arange(len(obs_), dtype='i8')\n",
    "    # row offset at which each row chunk of X ends (exclusive)\n",
    "    chunk_ends = list(np.cumsum(X.chunks[0]))\n",
    "    start_index = [0] + chunk_ends[:-1]\n",
    "    end_index = chunk_ends\n",
    "    # calculate divisions for dask dataframe; the last division is the last\n",
    "    # index label itself, hence the -1\n",
    "    divisions = [0] + chunk_ends\n",
    "    divisions[-1] = divisions[-1] - 1\n",
    "    ddf = dd.from_delayed(\n",
    "        [\n",
    "            convert_to_dataframe(arr, start, end)\n",
    "            for arr, start, end in zip(X.to_delayed().flatten().tolist(), start_index, end_index)\n",
    "        ],\n",
    "        divisions=divisions\n",
    "    )\n",
    "    obs_dask = dd.from_pandas(obs_, chunksize=CHUNK_SIZE)\n",
    "    # both frames have to be partitioned identically for the column-wise concat\n",
    "    assert np.allclose(ddf.divisions, obs_dask.divisions)\n",
    "    # dd.concat is the public name of the function formerly reached via dd.multi.concat\n",
    "    ddf = dd.concat([ddf, obs_dask], axis=1)\n",
    "\n",
    "    print(f'{split}: {X.shape[0]} cells')\n",
    "    ddf.to_parquet(\n",
    "        join(OUT_PATH, split),\n",
    "        engine='pyarrow',\n",
    "        schema=schema,\n",
    "        write_metadata_file=True,\n",
    "        row_group_size=ROW_GROUP_SIZE\n",
    "    )\n",
    "\n",
    "    # free up memory\n",
    "    client.restart()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46c1696f-97d5-401d-b43d-36dd5b77bc48",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
